{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Estimators"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/estimator\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rILQuAiiRlI7"
      },
      "source": [
        "> Warning: TensorFlow 2.15 included the final release of the `tf-estimator` package. Estimators will not be available in TensorFlow 2.16 or after. See the [migration guide](https://www.tensorflow.org/guide/migrate/migrating_estimator) for more information about how to convert off of Estimators."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oEinLJt2Uowq"
      },
      "source": [
        "This document introduces `tf.estimator`—a high-level TensorFlow\n",
        "API. Estimators encapsulate the following actions:\n",
        "\n",
        "* Training\n",
        "* Evaluation\n",
        "* Prediction\n",
        "* Export for serving\n",
        "\n",
        "TensorFlow implements several pre-made Estimators. Custom estimators are still suported, but mainly as a backwards compatibility measure. **Custom estimators should not be used for new code**.  All Estimators—pre-made or custom ones—are classes based on the `tf.estimator.Estimator` class.\n",
        "\n",
        "For a quick example, try [Estimator tutorials](../tutorials/estimator/linear.ipynb). For an overview of the API design, check the [white paper](https://arxiv.org/abs/1708.02637)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KLdnqg4G2bmz"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cXRQ6mRM5gk0"
      },
      "outputs": [],
      "source": [
        "!pip install -U tensorflow_datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J_-C9ty22dkD"
      },
      "outputs": [],
      "source": [
        "import tempfile\n",
        "import os\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wg5zbBliQvNL"
      },
      "source": [
        "## Advantages\n",
        "\n",
        "Similar to a `tf.keras.Model`, an `estimator` is a model-level abstraction. The `tf.estimator` provides some capabilities currently still under development for `tf.keras`. These are:\n",
        "\n",
        "  * Parameter server based training\n",
        "  * Full [TFX](http://tensorflow.org/tfx) integration"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yQ8fQYt_VD5E"
      },
      "source": [
        "## Estimators Capabilities\n",
        "Estimators provide the following benefits:\n",
        "\n",
        "* You can run Estimator-based models on a local host or on a distributed multi-server environment without changing your model. Furthermore, you can run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.\n",
        "* Estimators provide a safe distributed training loop that controls how and when to:    \n",
        "    * Load data\n",
        "    * Handle exceptions\n",
        "    * Create checkpoint files and recover from failures\n",
        "    * Save summaries for TensorBoard\n",
        "\n",
        "When writing an application with Estimators, you must separate the data input pipeline from the model. This separation simplifies experiments with different datasets."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jQ2PsufpgIpM"
      },
      "source": [
        "## Using pre-made Estimators\n",
        "\n",
        "Pre-made Estimators enable you to work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions since Estimators handle all the \"plumbing\" for you. Furthermore, pre-made Estimators let you experiment with different model architectures by making only minimal code changes.  `tf.estimator.DNNClassifier`, for example, is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks.\n",
        "\n",
        "A TensorFlow program relying on a pre-made Estimator typically consists of the following four steps:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mIJPPe26gQpF"
      },
      "source": [
        "### 1. Write an input functions\n",
        "\n",
        "For example, you might create one function to import the training set and another function to import the test set. Estimators expect their inputs to be formatted as a pair of objects:\n",
        "\n",
        "* A dictionary in which the keys are feature names and the values are Tensors (or SparseTensors) containing the corresponding feature data\n",
        "* A Tensor containing one or more labels\n",
        "\n",
        "The `input_fn` should return a `tf.data.Dataset` that yields pairs in that format. \n",
        "\n",
        "For example, the following code builds a `tf.data.Dataset` from the Titanic dataset's `train.csv` file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7fl_C5d6hEl3"
      },
      "outputs": [],
      "source": [
        "def train_input_fn():\n",
        "  titanic_file = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")\n",
        "  titanic = tf.data.experimental.make_csv_dataset(\n",
        "      titanic_file, batch_size=32,\n",
        "      label_name=\"survived\")\n",
        "  titanic_batches = (\n",
        "      titanic.cache().repeat().shuffle(500)\n",
        "      .prefetch(tf.data.AUTOTUNE))\n",
        "  return titanic_batches"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CjyrQGb3mCcp"
      },
      "source": [
        "The `input_fn` is executed in a `tf.Graph` and can also directly return a `(features_dics, labels)` pair containing graph tensors, but this is error prone outside of simple cases like returning constants."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yJYjWUMxgTnq"
      },
      "source": [
        "### 2. Define the feature columns.\n",
        "\n",
        "Each `tf.feature_column` identifies a feature name, its type, and any input pre-processing. \n",
        "\n",
        "For example, the following snippet creates three feature columns.\n",
        "\n",
        "- The first uses the `age` feature directly as a floating-point input. \n",
        "- The second uses the `class` feature as a categorical input.\n",
        "- The third uses the `embark_town` as a categorical input, but uses the `hashing trick` to avoid the need to enumerate the options, and to set the number of options.\n",
        "\n",
        "For further information, check the [feature columns tutorial](https://www.tensorflow.org/tutorials/keras/feature_columns)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFd8Dnrmhjhr"
      },
      "outputs": [],
      "source": [
        "age = tf.feature_column.numeric_column('age')\n",
        "cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) \n",
        "embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)"
      ]
    },
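    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The hashing trick mentioned above can be sketched in plain Python. This is a minimal illustration, not TensorFlow's implementation: `hash_bucket` is a hypothetical helper, and it uses `hashlib.md5` as a stand-in for the fingerprint function that `categorical_column_with_hash_bucket` applies internally, so the bucket ids it prints will not match TensorFlow's. The point is only that each string maps deterministically to one of a fixed number of buckets without a vocabulary, and that distinct strings may collide in the same bucket."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import hashlib\n",
        "\n",
        "def hash_bucket(value, num_buckets):\n",
        "  # Hypothetical helper: maps a string to a bucket id in [0, num_buckets).\n",
        "  # md5 is an assumption made for this sketch, not the hash TensorFlow uses.\n",
        "  digest = hashlib.md5(value.encode('utf-8')).hexdigest()\n",
        "  return int(digest, 16) % num_buckets\n",
        "\n",
        "# No vocabulary is needed; only the number of buckets is fixed up front.\n",
        "for town in ['Southampton', 'Cherbourg', 'Queenstown']:\n",
        "  print(town, '->', hash_bucket(town, 32))"
      ]
    },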
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UIjqAozjgXdr"
      },
      "source": [
        "### 3. Instantiate the relevant pre-made Estimator.\n",
        "\n",
        "For example, here's a sample instantiation of a pre-made Estimator named `LinearClassifier`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CDOx6lZVoVB8"
      },
      "outputs": [],
      "source": [
        "model_dir = tempfile.mkdtemp()\n",
        "model = tf.estimator.LinearClassifier(\n",
        "    model_dir=model_dir,\n",
        "    feature_columns=[embark, cls, age],\n",
        "    n_classes=2\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QGl9oYuFoYj6"
      },
      "source": [
        "For more information, you can go the [linear classifier tutorial](https://www.tensorflow.org/tutorials/estimator/linear)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sXNBeY-oVxGQ"
      },
      "source": [
        "### 4. Call a training, evaluation, or inference method.\n",
        "\n",
        "All Estimators provide `train`, `evaluate`, and `predict` methods.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iGaJKkmVBgo2"
      },
      "outputs": [],
      "source": [
        "model = model.train(input_fn=train_input_fn, steps=100)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CXkivCNq0vfH"
      },
      "outputs": [],
      "source": [
        "result = model.evaluate(train_input_fn, steps=10)\n",
        "\n",
        "for key, value in result.items():\n",
        "  print(key, \":\", value)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CPLD8n4CLVi_"
      },
      "outputs": [],
      "source": [
        "for pred in model.predict(train_input_fn):\n",
        "  for key, value in pred.items():\n",
        "    print(key, \":\", value)\n",
        "  break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cbmrm9pFg5vo"
      },
      "source": [
        "### Benefits of pre-made Estimators\n",
        "\n",
        "Pre-made Estimators encode best practices, providing the following benefits:\n",
        "\n",
        "* Best practices for determining where different parts of the computational graph should run, implementing strategies on a single machine or on a\n",
        "    cluster.\n",
        "*   Best practices for event (summary) writing and universally useful\n",
        "    summaries.\n",
        "\n",
        "If you don't use pre-made Estimators, you must implement the preceding features yourself."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oIaPjYgnZdn6"
      },
      "source": [
        "## Custom Estimators\n",
        "\n",
        "The heart of every Estimator—whether pre-made or custom—is its *model function*, `model_fn`, which is a method that builds graphs for training, evaluation, and prediction. When you are using a pre-made Estimator, someone else has already implemented the model function. When relying on a custom Estimator, you must write the model function yourself.\n",
        "\n",
        "> Note: A custom `model_fn` will still run in 1.x-style graph mode. This means there is no eager execution and no automatic control dependencies. You should plan to migrate away from `tf.estimator` with custom `model_fn`. The alternative APIs are `tf.keras` and `tf.distribute`. If you still need an `Estimator` for some part of your training you can use the `tf.keras.estimator.model_to_estimator` converter to create an `Estimator` from a `keras.Model`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P7aPNnXUbN4j"
      },
      "source": [
        "## Create an Estimator from a Keras model\n",
        "\n",
        "You can convert existing Keras models to Estimators with `tf.keras.estimator.model_to_estimator`. This is helpful if you want to modernize your model code, but your training pipeline still requires Estimators. \n",
        "\n",
        "Instantiate a Keras MobileNet V2 model and compile the model with the optimizer, loss, and metrics to train with:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XE6NMcuGeDOP"
      },
      "outputs": [],
      "source": [
        "keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(\n",
        "    input_shape=(160, 160, 3), include_top=False)\n",
        "keras_mobilenet_v2.trainable = False\n",
        "\n",
        "estimator_model = tf.keras.Sequential([\n",
        "    keras_mobilenet_v2,\n",
        "    tf.keras.layers.GlobalAveragePooling2D(),\n",
        "    tf.keras.layers.Dense(1)\n",
        "])\n",
        "\n",
        "# Compile the model\n",
        "estimator_model.compile(\n",
        "    optimizer='adam',\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A3hcxzcEfYfX"
      },
      "source": [
        "Create an `Estimator` from the compiled Keras model. The initial model state of the Keras model is preserved in the created `Estimator`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UCSSifirfyHk"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8jRNRVb_fzGT"
      },
      "source": [
        "Treat the derived `Estimator` as you would with any other `Estimator`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rv9xJk51e1fB"
      },
      "outputs": [],
      "source": [
        "IMG_SIZE = 160  # All images will be resized to 160x160\n",
        "\n",
        "def preprocess(image, label):\n",
        "  image = tf.cast(image, tf.float32)\n",
        "  image = (image/127.5) - 1\n",
        "  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))\n",
        "  return image, label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fw8OjwujVBkc"
      },
      "outputs": [],
      "source": [
        "def train_input_fn(batch_size):\n",
        "  data = tfds.load('cats_vs_dogs', as_supervised=True)\n",
        "  train_data = data['train']\n",
        "  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)\n",
        "  return train_data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JMb0cuy0gbTi"
      },
      "source": [
        "To train, call Estimator's train function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4JsvMp8Jge80"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jvr_rAzngY9v"
      },
      "source": [
        "Similarly, to evaluate, call the Estimator's evaluate function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kVNPqysQgYR2"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5HeTOvCYbjZb"
      },
      "source": [
        "For more details, please refer to the documentation for `tf.keras.estimator.model_to_estimator`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zGG1tOM0L6iM"
      },
      "source": [
        "## Saving object-based checkpoints with Estimator\n",
        "\n",
        "Estimators by default save checkpoints with variable names rather than the object graph described in the [Checkpoint guide](checkpoint.ipynb). `tf.train.Checkpoint` will read name-based checkpoints, but variable names may change when moving parts of a model outside of the Estimator's `model_fn`. For forwards compatibility saving object-based checkpoints makes it easier to train a model inside an Estimator and then use it outside of one."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-8AMJeueNyoM"
      },
      "outputs": [],
      "source": [
        "import tensorflow.compat.v1 as tf_compat"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W5JbCEUGY-Xo"
      },
      "outputs": [],
      "source": [
        "def toy_dataset():\n",
        "  inputs = tf.range(10.)[:, None]\n",
        "  labels = inputs * 5. + tf.range(5.)[None, :]\n",
        "  return tf.data.Dataset.from_tensor_slices(\n",
        "    dict(x=inputs, y=labels)).repeat().batch(2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gTZbsIRCZnCU"
      },
      "outputs": [],
      "source": [
        "class Net(tf.keras.Model):\n",
        "  \"\"\"A simple linear model.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(Net, self).__init__()\n",
        "    self.l1 = tf.keras.layers.Dense(5)\n",
        "\n",
        "  def call(self, x):\n",
        "    return self.l1(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T6fQsBzJQN2y"
      },
      "outputs": [],
      "source": [
        "def model_fn(features, labels, mode):\n",
        "  net = Net()\n",
        "  opt = tf.keras.optimizers.Adam(0.1)\n",
        "  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),\n",
        "                             optimizer=opt, net=net)\n",
        "  with tf.GradientTape() as tape:\n",
        "    output = net(features['x'])\n",
        "    loss = tf.reduce_mean(tf.abs(output - features['y']))\n",
        "  variables = net.trainable_variables\n",
        "  gradients = tape.gradient(loss, variables)\n",
        "  return tf.estimator.EstimatorSpec(\n",
        "    mode,\n",
        "    loss=loss,\n",
        "    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),\n",
        "                      ckpt.step.assign_add(1)),\n",
        "    # Tell the Estimator to save \"ckpt\" in an object-based format.\n",
        "    scaffold=tf_compat.train.Scaffold(saver=ckpt))\n",
        "\n",
        "tf.keras.backend.clear_session()\n",
        "est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')\n",
        "est.train(toy_dataset, steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tObYHnrrb_mL"
      },
      "source": [
        "`tf.train.Checkpoint` can then load the Estimator's checkpoints from its `model_dir`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q6IP3Y_wb-fs"
      },
      "outputs": [],
      "source": [
        "opt = tf.keras.optimizers.Adam(0.1)\n",
        "net = Net()\n",
        "ckpt = tf.train.Checkpoint(\n",
        "  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)\n",
        "ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))\n",
        "ckpt.step.numpy()  # From est.train(..., steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dk5wWyuMpuHx"
      },
      "source": [
        "## SavedModels from Estimators\n",
        "\n",
        "Estimators export SavedModels through `tf.Estimator.export_saved_model`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B9KQq5qzpzbK"
      },
      "outputs": [],
      "source": [
        "input_column = tf.feature_column.numeric_column(\"x\")\n",
        "\n",
        "estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])\n",
        "\n",
        "def input_fn():\n",
        "  return tf.data.Dataset.from_tensor_slices(\n",
        "    ({\"x\": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)\n",
        "estimator.train(input_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y9qCa6J6FVS5"
      },
      "source": [
        "To save an `Estimator` you need to create a `serving_input_receiver`. This function builds a part of a `tf.Graph` that parses the raw data received by the SavedModel. \n",
        "\n",
        "The `tf.estimator.export` module contains functions to help build these `receivers`.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XJ4PJ-Cl4060"
      },
      "source": [
        "The following code builds a receiver, based on the `feature_columns`, that accepts serialized `tf.Example` protocol buffers, which are often used with [tf-serving](https://tensorflow.org/serving)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lnmsmGOQFPED"
      },
      "outputs": [],
      "source": [
        "tmpdir = tempfile.mkdtemp()\n",
        "\n",
        "serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
        "  tf.feature_column.make_parse_example_spec([input_column]))\n",
        "\n",
        "estimator_base_path = os.path.join(tmpdir, 'from_estimator')\n",
        "estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q7XtbLMDaie2"
      },
      "source": [
        "You can also load and run that model, from python:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c_BUBBNB1UH9"
      },
      "outputs": [],
      "source": [
        "imported = tf.saved_model.load(estimator_path)\n",
        "\n",
        "def predict(x):\n",
        "  example = tf.train.Example()\n",
        "  example.features.feature[\"x\"].float_list.value.extend([x])\n",
        "  return imported.signatures[\"predict\"](\n",
        "    examples=tf.constant([example.SerializeToString()]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C1ylWZCQ1ahG"
      },
      "outputs": [],
      "source": [
        "print(predict(1.5))\n",
        "print(predict(3.5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_IrCCm0-isqA"
      },
      "source": [
        "`tf.estimator.export.build_raw_serving_input_receiver_fn` allows you to create input functions which take raw tensors rather than `tf.train.Example`s."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nO0hmFCRoIll"
      },
      "source": [
        "## Using `tf.distribute.Strategy` with Estimator (Limited support)\n",
        "\n",
        "`tf.estimator` is a distributed training TensorFlow API that originally supported the async parameter server approach. `tf.estimator` now supports `tf.distribute.Strategy`. If you're using `tf.estimator`, you can change to distributed training with very few changes to your code. With this, Estimator users can now do synchronous distributed training on multiple GPUs and multiple workers, as well as use TPUs. This support in Estimator is, however, limited. Check out the [What's supported now](#estimator_support) section below for more details.\n",
        "\n",
        "Using `tf.distribute.Strategy` with Estimator is slightly different than in the Keras case. Instead of using `strategy.scope`, now you pass the strategy object into the `RunConfig` for the Estimator.\n",
        "\n",
        "You can refer to the [distributed training guide](distributed_training.ipynb) for more information.\n",
        "\n",
        "Here is a snippet of code that shows this with a premade Estimator `LinearRegressor` and `MirroredStrategy`:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oGFY5nW_B3YU"
      },
      "outputs": [],
      "source": [
        "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
        "config = tf.estimator.RunConfig(\n",
        "    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)\n",
        "regressor = tf.estimator.LinearRegressor(\n",
        "    feature_columns=[tf.feature_column.numeric_column('feats')],\n",
        "    optimizer='SGD',\n",
        "    config=config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n6eSfLN5RGY8"
      },
      "source": [
        "Here, you use a premade Estimator, but the same code works with a custom Estimator as well. `train_distribute` determines how training will be distributed, and `eval_distribute` determines how evaluation will be distributed. This is another difference from Keras where you use the same strategy for both training and eval.\n",
        "\n",
        "Now you can train and evaluate this Estimator with an input function:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ky2ve2PB3YP"
      },
      "outputs": [],
      "source": [
        "def input_fn():\n",
        "  dataset = tf.data.Dataset.from_tensors(({\"feats\":[1.]}, [1.]))\n",
        "  return dataset.repeat(1000).batch(10)\n",
        "regressor.train(input_fn=input_fn, steps=10)\n",
        "regressor.evaluate(input_fn=input_fn, steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hgaU9xQSSk2x"
      },
      "source": [
        "Another difference to highlight here between Estimator and Keras is the input handling. In Keras, each batch of the dataset is split automatically across the multiple replicas. In Estimator, however, you do not perform automatic batch splitting, nor automatically shard the data across different workers. You have full control over how you want your data to be distributed across workers and devices, and you must provide an `input_fn` to specify how to distribute your data.\n",
        "\n",
        "Your `input_fn` is called once per worker, thus giving one dataset per worker. Then one batch from that dataset is fed to one replica on that worker, thereby consuming N batches for N replicas on 1 worker. In other words, the dataset returned by the `input_fn` should provide batches of size `PER_REPLICA_BATCH_SIZE`. And the global batch size for a step can be obtained as `PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync`.\n",
        "\n",
        "When performing multi-worker training, you should either split your data across the workers, or shuffle with a random seed on each. You can check an example of how to do this in the [Multi-worker training with Estimator](../tutorials/distribute/multi_worker_with_estimator.ipynb) tutorial."
      ]
    },
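    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The batch-size arithmetic and the per-worker split described above can be sketched without TensorFlow. The values below are made up for illustration: `NUM_REPLICAS_IN_SYNC` stands in for `strategy.num_replicas_in_sync`, and `shard_records` is a hypothetical helper that mimics what `tf.data.Dataset.shard(num_shards, index)` does to a dataset inside an `input_fn`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "PER_REPLICA_BATCH_SIZE = 16  # Batch size produced by the input_fn (example value).\n",
        "NUM_REPLICAS_IN_SYNC = 4     # Stand-in for strategy.num_replicas_in_sync.\n",
        "\n",
        "# Each replica consumes one batch per step, so a step sees this many examples.\n",
        "global_batch_size = PER_REPLICA_BATCH_SIZE * NUM_REPLICAS_IN_SYNC\n",
        "print('Global batch size:', global_batch_size)\n",
        "\n",
        "def shard_records(records, num_workers, worker_index):\n",
        "  # Hypothetical helper: keeps every num_workers-th record, starting at worker_index.\n",
        "  return records[worker_index::num_workers]\n",
        "\n",
        "# With two workers, each worker reads a disjoint half of the records.\n",
        "records = list(range(10))\n",
        "for worker_index in range(2):\n",
        "  print('worker', worker_index, 'reads', shard_records(records, 2, worker_index))"
      ]
    },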
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3ieQKfWZhhL"
      },
      "source": [
        "And similarly, you can use multi worker and parameter server strategies as well. The code remains the same, but you need to use `tf.estimator.train_and_evaluate`, and set `TF_CONFIG` environment variables for each binary running in your cluster."
      ]
    },
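    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what setting `TF_CONFIG` can look like: the variable holds a JSON string with a `cluster` entry listing every task's address and a `task` entry identifying the current process. The host names and port below are placeholders, not real endpoints, and the variable is assumed to be set before the `RunConfig` is created. Each process in the cluster would use the same `cluster` entry but its own `task` entry."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "\n",
        "# Hypothetical two-worker cluster; replace the addresses with your own hosts.\n",
        "tf_config = {\n",
        "    'cluster': {\n",
        "        'worker': ['worker0.example.com:2222', 'worker1.example.com:2222']\n",
        "    },\n",
        "    # This process is worker 0; the second worker would use index 1.\n",
        "    'task': {'type': 'worker', 'index': 0}\n",
        "}\n",
        "os.environ['TF_CONFIG'] = json.dumps(tf_config)\n",
        "\n",
        "print(os.environ['TF_CONFIG'])"
      ]
    },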
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A_lvUsSLZzVg"
      },
      "source": [
        "<a name=\"estimator_support\"></a>\n",
        "\n",
        "### What's supported now?\n",
        "\n",
        "There is limited support for training with Estimator using all strategies except `TPUStrategy`. Basic training and evaluation should work, but a number of advanced features such as `v1.train.Scaffold` do not. There may also be a number of bugs in this integration and there are no plans to actively improve this support (the focus is on Keras and custom training loop support). If at all possible, you should prefer to use `tf.distribute` with those APIs instead.\n",
        "\n",
        "| Training API  \t| MirroredStrategy \t| TPUStrategy \t| MultiWorkerMirroredStrategy \t| CentralStorageStrategy \t| ParameterServerStrategy \t|\n",
        "|:---------------\t|:------------------\t|:-------------\t|:-----------------------------\t|:------------------------\t|:-------------------------\t|\n",
        "| Estimator API \t| Limited support        \t| Not supported   \t| Limited support                   \t| Limited support              \t| Limited support               \t|\n",
        "\n",
        "### Examples and tutorials\n",
        "\n",
        "Here are some end-to-end examples that show how to use various strategies with Estimator:\n",
        "\n",
        "1. The [Multi-worker Training with Estimator tutorial](../tutorials/distribute/multi_worker_with_estimator.ipynb) shows how you can train with multiple workers using `MultiWorkerMirroredStrategy` on the MNIST dataset.\n",
        "2. An end-to-end example of [running multi-worker training with distribution strategies](https://github.com/tensorflow/ecosystem/tree/master/distribution_strategy) in `tensorflow/ecosystem` using Kubernetes templates. It starts with a Keras model and converts it to an Estimator using the `tf.keras.estimator.model_to_estimator` API.\n",
        "3. The official [ResNet50](https://github.com/tensorflow/models/blob/master/official/vision/image_classification/resnet_imagenet_main.py) model, which can be trained using either `MirroredStrategy` or `MultiWorkerMirroredStrategy`."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L",
        "KLdnqg4G2bmz",
        "Wg5zbBliQvNL",
        "yQ8fQYt_VD5E",
        "jQ2PsufpgIpM",
        "mIJPPe26gQpF",
        "yJYjWUMxgTnq",
        "UIjqAozjgXdr",
        "sXNBeY-oVxGQ",
        "cbmrm9pFg5vo",
        "oIaPjYgnZdn6",
        "P7aPNnXUbN4j",
        "zGG1tOM0L6iM",
        "Dk5wWyuMpuHx",
        "nO0hmFCRoIll",
        "A_lvUsSLZzVg"
      ],
      "name": "estimator.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
